{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import functools\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import soundfile\n",
    "import matplotlib\n",
    "import matplotlib.pylab as plt\n",
    "from IPython.display import display, Audio\n",
    "\n",
    "from einops import rearrange\n",
    "\n",
    "from nara_wpe.utils import stft, istft\n",
    "\n",
    "from pb_bss.distribution import CACGMMTrainer, CBMMTrainer, CWMMTrainer\n",
    "from pb_bss.permutation_alignment import DHTVPermutationAlignment, OraclePermutationAlignment\n",
    "from pb_bss.evaluation import InputMetrics, OutputMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def string_function(a):\n",
    "    \"\"\"Return a compact text representation of the array `a`.\n",
    "\n",
    "    Arrays with fewer than 50 elements are converted with `str`, larger\n",
    "    arrays are summarized by their shape and dtype.\n",
    "    \"\"\"\n",
    "    if a.size < 50:\n",
    "        return str(a)\n",
    "    return f'array(shape={a.shape}, dtype={a.dtype})'\n",
    "\n",
    "\n",
    "# `np.set_string_function` was removed in NumPy 2.0. Register the compact\n",
    "# repr where the hook still exists; otherwise fall back to print options\n",
    "# that summarize arrays with more than 50 elements, so that the notebook\n",
    "# keeps running on current NumPy versions.\n",
    "if hasattr(np, 'set_string_function'):\n",
    "    np.set_string_function(string_function)\n",
    "else:\n",
    "    np.set_printoptions(threshold=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read some data\n",
    "You can use `soundfile.read` directly, when reading local files.\n",
    "\n",
    "Read the observation $\\mathbf y_n$ and the source images $\\mathbf x_{0, n}$ and $\\mathbf x_{1, n}$ and the noise $\\mathbf n_n$\n",
    "$$\n",
    "\\mathbf{y}_n = \\mathbf x_{0, n} + \\mathbf x_{1, n} + \\mathbf n_{n} \\in \\mathbb R ^ D\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.request import urlopen\n",
    "\n",
    "sample_rate = 8000\n",
    "\n",
    "@functools.lru_cache()\n",
    "def soundfile_read(url):\n",
    "    \"\"\"Download the audio file behind `url` and return its samples.\n",
    "\n",
    "    The result is memoized per URL, so the download and the log message\n",
    "    only happen on the first call for a given URL.\n",
    "\n",
    "    Args:\n",
    "        url: Location of an audio file that `soundfile.read` can decode.\n",
    "\n",
    "    Returns:\n",
    "        The transposed, C-contiguous sample array of `soundfile.read`\n",
    "        (presumably channels x samples for multi-channel files).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the sample rate of the file differs from the\n",
    "            global `sample_rate`.\n",
    "    \"\"\"\n",
    "    # Close the HTTP response explicitly instead of leaving the open\n",
    "    # connection to the garbage collector.\n",
    "    with urlopen(url) as response:\n",
    "        payload = response.read()\n",
    "\n",
    "    data, data_sample_rate = soundfile.read(io.BytesIO(payload))\n",
    "\n",
    "    assert sample_rate == data_sample_rate, (sample_rate, data_sample_rate)\n",
    "\n",
    "    print(f'Read: {url}.\\nSample rate: {data_sample_rate}')\n",
    "    return np.ascontiguousarray(data.T)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# All example signals are stored in the same directory of the test data repository.\n",
    "data_url = 'https://github.com/fgnt/pb_test_data/raw/master/bss_data/low_reverberation'\n",
    "\n",
    "observation = soundfile_read(f'{data_url}/observation.wav')\n",
    "\n",
    "speech_image_0 = soundfile_read(f'{data_url}/speech_image_0.wav')\n",
    "speech_image_1 = soundfile_read(f'{data_url}/speech_image_1.wav')\n",
    "speech_image = np.array([speech_image_0, speech_image_1])\n",
    "\n",
    "noise_image = soundfile_read(f'{data_url}/noise_image.wav')\n",
    "\n",
    "speech_source_0 = soundfile_read(f'{data_url}/speech_source_0.wav')\n",
    "speech_source_1 = soundfile_read(f'{data_url}/speech_source_1.wav')\n",
    "speech_source = np.array([speech_source_0, speech_source_1])\n",
    "\n",
    "observation, speech_image, noise_image, speech_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mask(mask, *, ax=None):\n",
    "    \"\"\"Draw `mask` as an image with a color scale fixed to [0, 1].\n",
    "\n",
    "    Args:\n",
    "        mask: Array that is handed to `imshow`. The first row is drawn at\n",
    "            the bottom (`origin='lower'`).\n",
    "        ax: Axes to draw into. Defaults to the current axes.\n",
    "\n",
    "    Returns:\n",
    "        The axes that contain the image and its colorbar.\n",
    "    \"\"\"\n",
    "    axes = plt.gca() if ax is None else ax\n",
    "    imshow_options = dict(interpolation='nearest', vmin=0, vmax=1, origin='lower')\n",
    "    plt.colorbar(axes.imshow(mask, **imshow_options), ax=axes)\n",
    "    return axes\n",
    "\n",
    "def plot_stft(stft_signal, *, ax=None):\n",
    "    \"\"\"Plot the energy of an STFT signal in dB.\n",
    "\n",
    "    Args:\n",
    "        stft_signal: Array with the (complex) STFT coefficients. The dB\n",
    "            values are handed to `imshow`, the first row is drawn at the\n",
    "            bottom (`origin='lower'`).\n",
    "        ax: Axes to draw into. Defaults to the current axes.\n",
    "\n",
    "    Returns:\n",
    "        The axes that contain the image and its colorbar.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    # Energy is the squared magnitude. Converting the plain magnitude with\n",
    "    # 10 * log10 would show only half of the dB values on the colorbar.\n",
    "    energy = np.abs(stft_signal) ** 2\n",
    "\n",
    "    # Limit the dynamic range to 60 dB below the maximum (factor 1e6 in the\n",
    "    # energy domain). The floor also keeps zero-valued bins away from\n",
    "    # log10(0) as long as the signal is not zero everywhere.\n",
    "    energy_db = 10 * np.log10(np.maximum(energy, np.max(energy) / 1e6))\n",
    "\n",
    "    image = ax.imshow(\n",
    "        energy_db,\n",
    "        interpolation='nearest',\n",
    "        origin='lower',\n",
    "    )\n",
    "    cbar = plt.colorbar(image, ax=ax)\n",
    "    cbar.set_label('Energy / dB')\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate the STFT signals\n",
    "A variable name that starts with an uppercase letter denotes an STFT signal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identical STFT settings for every signal keep the time-frequency grids comparable.\n",
    "Observation, Speech_image, Noise_image = [\n",
    "    stft(signal, 512, 128)\n",
    "    for signal in (observation, speech_image, noise_image)\n",
    "]\n",
    "Observation, Speech_image, Noise_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_stft(Observation[0].T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a mixture model on each frequency\n",
    "\n",
    "- Instantiate a Trainer \n",
    "  - In our experiments most of the time the cACGMM yields the best results\n",
    "  - The Trainer is stateful, because some distributions (e.g. complex Watson) instantiate a numeric solver\n",
    "- Call the fit function. There are three ways to start the EM algorithm:\n",
    "  - initialization: np.array with shape (..., K, N)\n",
    "     - An affiliation mask that indicates the initial probabilities.\n",
    "  - num_classes: Scalar\n",
    "     - Calculates an i.i.d. initialization mask and falls back to the case above\n",
    "  - initialization: Probability instance (i.e. the returned model of the fit function, e.g. a CACGMM instance)\n",
    "     - Provide a trained model.\n",
    "- Call predict on the trained model to obtain the affiliation (i.e. posterior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = CACGMMTrainer()\n",
    "# trainer = CWMMTrainer()\n",
    "# trainer = CBMMTrainer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Observation_mm = rearrange(Observation, 'd t f -> f t d')\n",
    "\n",
    "model = trainer.fit(\n",
    "    Observation_mm,\n",
    "    num_classes=3,\n",
    "    iterations=40,\n",
    "    inline_permutation_aligner=None\n",
    ")\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "affiliation = model.predict(Observation_mm)\n",
    "affiliation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Permutation alignment\n",
    " - The model is trained for each frequency independently. So the permutation between all frequencies is random. (See next plot)\n",
    " - Apply a permutation alignment between the frequencies. It was originally implemented by Dang Hai Tran Vu.\n",
    " - Calculate the global permutation to identify the speakers and the noise (This uses oracle information)\n",
    "   - Use as enhanced signal the observation multiplied with the mask (ToDo: Move beamformer code to pb_bss and use beamforming instead of masking)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_mask(affiliation[:, 0, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pa = DHTVPermutationAlignment.from_stft_size(512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Bring the class axis to the front once: 'f k t' -> 'k f t'.\n",
    "affiliation_kft = rearrange(affiliation, 'f k t -> k f t')\n",
    "\n",
    "mapping = pa.calculate_mapping(affiliation_kft)\n",
    "affiliation_pa = pa.apply_mapping(affiliation_kft, mapping)\n",
    "\n",
    "# Alternative to obtain affiliation_pa:\n",
    "#     affiliation_pa = pa(rearrange(affiliation, 'f k t -> k f t'))\n",
    "\n",
    "affiliation_pa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Observation, Speech_image, Noise_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Masked observation for every class and channel. The dummy axes (d=1, k=1)\n",
    "# let the masks and the observation broadcast to the layout 'k d t f'.\n",
    "Masked_observation = (\n",
    "    rearrange(affiliation_pa, 'k f (t d) -> k d t f', d=1)\n",
    "    * rearrange(Observation, 'd t (f k) -> k d t f', k=1)\n",
    ")\n",
    "\n",
    "# One flat vector per class for the estimates and for the references.\n",
    "global_pa_est = rearrange(Masked_observation, 'k d t f -> k (d t f)')\n",
    "global_pa_reference = rearrange(\n",
    "    np.array([*Speech_image, Noise_image]),\n",
    "    'k d t f -> k (d t f)',\n",
    ")\n",
    "\n",
    "# Oracle information: the estimates are compared with the known images.\n",
    "global_pa = OraclePermutationAlignment()\n",
    "global_permutation = global_pa.calculate_mapping(global_pa_est, global_pa_reference)\n",
    "\n",
    "# Reorder the masks along the class axis (presumably to the order of the\n",
    "# reference: both speech images first, then the noise).\n",
    "affiliation_pa = affiliation_pa[global_permutation]\n",
    "global_permutation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the output\n",
    " - Display the masks, the enhanced signals and the clean signals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference_channel = 0\n",
    "titles = ['Speaker 0', 'Speaker 1', 'Noise']\n",
    "\n",
    "# Row 1: the aligned masks.\n",
    "f, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "for ax, mask, title in zip(axes, affiliation_pa, titles):\n",
    "    plot_mask(mask, ax=ax)\n",
    "    ax.set_title(f'Mask: {title}')\n",
    "\n",
    "# Masking: weight the reference channel of the observation with each mask.\n",
    "# The transpose brings the observation from 't f' to the 'f t' layout of the masks.\n",
    "Observation_reference = Observation[reference_channel, :, :].T\n",
    "Speech_image_0_est = Observation_reference * affiliation_pa[0, :, :]\n",
    "Speech_image_1_est = Observation_reference * affiliation_pa[1, :, :]\n",
    "Noise_image_est = Observation_reference * affiliation_pa[2, :, :]\n",
    "\n",
    "# Row 2: the enhanced signals.\n",
    "f, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "estimates = [Speech_image_0_est, Speech_image_1_est, Noise_image_est]\n",
    "for ax, Estimate, title in zip(axes, estimates, titles):\n",
    "    plot_stft(Estimate, ax=ax)\n",
    "    ax.set_title(f'Estimate: {title}')\n",
    "\n",
    "# Row 3: the clean signals at the reference channel.\n",
    "f, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "references = [*Speech_image[:, reference_channel], Noise_image[reference_channel]]\n",
    "for ax, Reference, title in zip(axes, references, titles):\n",
    "    plot_stft(Reference.T, ax=ax)\n",
    "    ax.set_title(f'Reference: {title}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Speech_image_0_est"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform the masked signals back to the time domain and cut them to\n",
    "# the number of samples of the observation.\n",
    "num_samples = observation.shape[-1]\n",
    "speech_image_0_est, speech_image_1_est, noise_image_est = [\n",
    "    istft(Estimate.T, 512, 128)[..., :num_samples]\n",
    "    for Estimate in (Speech_image_0_est, Speech_image_1_est, Noise_image_est)\n",
    "]\n",
    "speech_image_0_est, speech_image_1_est, noise_image_est"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(\n",
    "    'speech_image_0_est',\n",
    "    Audio(speech_image_0_est, rate=sample_rate),\n",
    "    'speech_image_1_est',\n",
    "    Audio(speech_image_1_est, rate=sample_rate),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metrics\n",
    "\n",
    "Some metrics are available in pb_bss. Some are wrappers around external libraries (e.g. peaq, stoi) and some are implemented in this package (e.g. invasive SXR). `mir_eval` is an external metric that received some slight modifications in this package.\n",
    "\n",
    "- Input metrics:\n",
    "  - Note: The input metrics are calculated for each speaker and each channel. \n",
    "    With all values given, the user can analyse the scores in detail (e.g. SNR per channel)\n",
    "- Output metrics:\n",
    "  - Note: invasive_sxr is only defined for linear enhancements."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_metric = InputMetrics(\n",
    "    observation=observation,\n",
    "    speech_source=speech_source,\n",
    "    speech_image=speech_image,\n",
    "    noise_image=noise_image,\n",
    "    sample_rate=sample_rate,\n",
    ")\n",
    "print(input_metric.as_dict().keys())\n",
    "print(input_metric.as_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the metrics for each speaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in input_metric.as_dict().items():\n",
    "    print(k, np.mean(v, axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the metrics for each channel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in input_metric.as_dict().items():\n",
    "    print(k, np.mean(v, axis=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the average metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in input_metric.as_dict().items():\n",
    "    print(k, np.mean(v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output metric\n",
    "\n",
    "Masking is a linear enhancement, so we can calculate the parts that the speech sources and the noise contribute to the estimate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both contributions use the same masks, rearranged once from 'k f t' to 'k t f'.\n",
    "masks_ktf = rearrange(affiliation_pa, 'k f t -> k t f')\n",
    "\n",
    "# Apply every mask to the reference channel of each speech image and of the noise.\n",
    "Speech_contribution = Speech_image[:, reference_channel, None, :, :] * masks_ktf\n",
    "Noise_contribution = Noise_image[reference_channel, :, :] * masks_ktf\n",
    "\n",
    "# Back to the time domain, cut to the number of samples of the observation.\n",
    "num_samples = observation.shape[-1]\n",
    "speech_contribution = istft(Speech_contribution, 512, 128)[..., :num_samples]\n",
    "noise_contribution = istft(Noise_contribution, 512, 128)[..., :num_samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_metric = OutputMetrics(\n",
    "    speech_prediction=np.array([speech_image_0_est, speech_image_1_est, noise_image_est]),\n",
    "    speech_source=speech_source,\n",
    "    speech_contribution=speech_contribution,\n",
    "    noise_contribution=noise_contribution,\n",
    "    sample_rate=sample_rate,\n",
    ")\n",
    "print(output_metric.as_dict().keys())\n",
    "output_metric.as_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the input scores once instead of once per loop iteration.\n",
    "input_scores = input_metric.as_dict()\n",
    "\n",
    "print(f'{\"Score\": <19}{\"in\": <22} + {\"gain\": <20} -> out')\n",
    "print('-' * 61)\n",
    "for name, output_score in output_metric.as_dict().items():\n",
    "    if name == 'mir_eval_sxr_selection':\n",
    "        # No input counterpart is looked up for this entry, print it as is.\n",
    "        print(name, output_score)\n",
    "        continue\n",
    "    score_in = np.mean(input_scores[name])\n",
    "    score_out = np.mean(output_score)\n",
    "    gain = score_out - score_in\n",
    "    print(f'{name+\":\": <19}{score_in: <22} + {gain: <20} -> {score_out}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
